{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating your own optical elements\n",
    "\n",
    "We will present the internal workings of optical elements, and show how you can create your own optical elements in HCIPy.\n",
    "\n",
    "## Standard optical elements\n",
    "\n",
    "We'll start by importing all the necessary packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hcipy import *\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import traceback"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An `OpticalElement` is an object that can propagate a `Wavefront` from one plane to another. This includes apodizers, deformable mirrors, focal-plane masks, coronagraphs or even complete optical systems. An optical element should in principle be able to handle any type of `Wavefront` passed to its `forward()` and `backward()` functions. This puts a large responsibility on the implementation of the optical element.\n",
    "\n",
    "Any proper implementation of an optical element should therefore implement a number of functions to help signal its intent. Here, we are going to implement a very simple `OpticalElement`: a neutral-density filter. We'll show the full implementation first, and then go in to the details for each implemented function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NeutralDensityFilter(OpticalElement):\n",
    "    def __init__(self, transmittance):\n",
    "        self.transmittance = transmittance\n",
    "    \n",
    "    def forward(self, wavefront):\n",
    "        wf = wavefront.copy()\n",
    "        \n",
    "        wf.electric_field *= np.sqrt(self.transmittance)\n",
    "        \n",
    "        return wf\n",
    "\n",
    "    def backward(self, wavefront):\n",
    "        wf = Wavefront.copy()\n",
    "        \n",
    "        wf.electric_field *= np.sqrt(self.transmittance)\n",
    "        \n",
    "        return wf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A neutral-density filter is a filter that reduces the intensity of the transmitted light independent of wavelength or spatial position. Therefore, the above implementation can handle any incoming wavefront. The initializer of the class just stores the passed transmittance. In the initializer, usually we prefer to do as much computation as possible, to aliviate the burden on the `forward()` and `backward()` functions. \n",
    "\n",
    "The `forward()` function propagates the wavefront through the filter. This function first creates a copy of the wavefront to work on, then modifies its electric field and returns it. Any implementation of `forward()` should not attempt to operate on a wavefront in-place and should always return a new `Wavefront` object.\n",
    "\n",
    "The `backward()` function might need some extra explanation. It does not implement the propagation in the opposite direction through the optical element, or even the inverse or pseudoinverse of the forward propagation, but rather the adjoint. In absense of non-linear effects, any optical element can be thought of as a linear operator on a complex Hilbert space. This Hilbert space is the space of all possible wavefronts, and an optical element acts on a wavefront to produce another wavefront in the same space. This means that we can also construct the Hermitan adjoint of this operator. The `backward()` function implements this Hermitian adjoint propagation. In some cases, ie. when energy is conserved in the optical element, the `backward()` function is the inverse of the `forward()` function, but *in general this is not the case*."
   ]
  },
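  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The adjoint relation can be checked numerically. The cell below is a minimal sketch in plain Python rather than HCIPy: the helper names `inner`, `forward_multiply`, `backward_multiply` and `max_difference` are made up for this illustration, and the optical element is assumed to be a simple element-wise multiplication by a complex apodization. It checks that the inner product of `forward(x)` with `y` equals the inner product of `x` with `backward(y)`, and that `backward()` undoes `forward()` only for the energy-conserving (phase-only) element."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cmath\n",
    "import random\n",
    "\n",
    "def inner(x, y):\n",
    "    # Complex inner product: sum of conj(x_i) * y_i.\n",
    "    return sum(a.conjugate() * b for a, b in zip(x, y))\n",
    "\n",
    "def forward_multiply(apodization, field):\n",
    "    # Forward propagation: element-wise multiplication.\n",
    "    return [a * e for a, e in zip(apodization, field)]\n",
    "\n",
    "def backward_multiply(apodization, field):\n",
    "    # Adjoint: multiply by the complex conjugate instead.\n",
    "    return [a.conjugate() * e for a, e in zip(apodization, field)]\n",
    "\n",
    "def max_difference(x, y):\n",
    "    # Largest element-wise absolute difference between two fields.\n",
    "    return max(abs(a - b) for a, b in zip(x, y))\n",
    "\n",
    "rng = random.Random(0)\n",
    "n = 16\n",
    "x = [complex(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(n)]\n",
    "y = [complex(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(n)]\n",
    "\n",
    "# Lossy element: amplitude 0.5 with a varying phase.\n",
    "lossy = [0.5 * cmath.exp(0.3j * k) for k in range(n)]\n",
    "# Energy-conserving element: unit amplitude, phase only.\n",
    "phase_only = [cmath.exp(0.3j * k) for k in range(n)]\n",
    "\n",
    "# Adjoint identity; should vanish up to rounding errors.\n",
    "adjoint_error = abs(inner(forward_multiply(lossy, x), y) - inner(x, backward_multiply(lossy, y)))\n",
    "\n",
    "# Round trips: backward(forward(x)) compared to x.\n",
    "lossy_round_trip_error = max_difference(backward_multiply(lossy, forward_multiply(lossy, x)), x)\n",
    "phase_round_trip_error = max_difference(backward_multiply(phase_only, forward_multiply(phase_only, x)), x)\n",
    "\n",
    "print('adjoint identity error:', adjoint_error)\n",
    "print('round-trip error, lossy element:', lossy_round_trip_error)\n",
    "print('round-trip error, phase-only element:', phase_round_trip_error)"
   ]
  },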
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agnostic optical elements\n",
    "\n",
    "### Motivation\n",
    "\n",
    "Sometimes we want more complicated behaviour. For instance, we might want the transmission of the optic to change as function of wavelength, or spatial position in the plane. As a case study, we take here an `ColourFilter` class, which implements the transmission through a colour filter. A naive implementation would be something along the lines of the following."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NaiveColourFilter(OpticalElement):\n",
    "    def __init__(self, filter_transmission):\n",
    "        self.filter_transmission = filter_transmission\n",
    "    \n",
    "    def forward(self, wavefront):\n",
    "        wf = wavefront.copy()\n",
    "        \n",
    "        transmission = self.filter_transmission(wavefront.wavelength)\n",
    "        wf.electric_field *= np.sqrt(transmission)\n",
    "        \n",
    "        return wf\n",
    "    \n",
    "    def backward(self, wavefront):\n",
    "        wf = wavefront.copy()\n",
    "        \n",
    "        transmission = self.filter_transmission(wavefront.wavelength)\n",
    "        wf.electric_field *= np.sqrt(transmission)\n",
    "        \n",
    "        return wf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We would initialize this class with a transmission which is a function of wavelength."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "central_wavelength = 700e-9 # m\n",
    "spectral_bandwidth = 100e-9 # m\n",
    "\n",
    "def block_filter(wavelength):\n",
    "    lower = central_wavelength - spectral_bandwidth / 2\n",
    "    upper = central_wavelength + spectral_bandwidth / 2\n",
    "    \n",
    "    return (lower < wavelength) and (wavelength < upper)\n",
    "\n",
    "colour_filter = NaiveColourFilter(block_filter)\n",
    "\n",
    "pupil_grid = make_pupil_grid(128)\n",
    "wf = Wavefront(pupil_grid.ones())\n",
    "post_filter = colour_filter(wf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, this implementation doesn't support a float as the filter transmission, as a float cannot be called with a wavelength."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    colour_filter = NaiveColourFilter(0.7)\n",
    "    post_ilter = colour_filter(wf)\n",
    "except Exception as e:\n",
    "    traceback.print_exc()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This narrow behaviour is fine for quick development, that is, in cases where we expect ourselves to be the only ones using the optical element and/or there are only a few use cases when we would want to use it. For more general optical elements, however, we would prefer the objects to be able to handle a wide variety of input arguments.\n",
    "\n",
    "Take for example the `LinearRetarder` optical element. The apodization pass as the argument to its initializer/constructor can be\n",
    "1. a scalar,\n",
    "2. a `Field`,\n",
    "3. a `Field` generator,\n",
    "4. a function of wavelength, returning either a `Field` or scalar, \n",
    "5. a function of both a `Grid` and a wavelength, returning a `Field`.\n",
    "\n",
    "If they  apodization will be evaluated at the `Grid` and wavelength of the incoming `Wavefront`. All of this is done seamlessly and is hidden from view for the user. Another example is that of the `LinearRetarder`, which should be able to handle (at least) a phase retardance which is either constant with wavelength (ie. a scalar) or a function of wavelength. Manually distinguising between the above cases can become quite cumbersome, so HCIPy offers a simpler solution.\n",
    "\n",
    "### Implementation details\n",
    "\n",
    "Let's look at the implementation of `Apodizer` as an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Apodizer(AgnosticOpticalElement):\n",
    "    def __init__(self, apodization):\n",
    "        self.apodization = apodization\n",
    "\n",
    "        AgnosticOpticalElement.__init__(self, grid_dependent=True, wavelength_dependent=True, max_in_cache=11)\n",
    "\n",
    "    def make_instance(self, instance_data, input_grid, output_grid, wavelength):\n",
    "        instance_data.apodization = self.evaluate_parameter(self.apodization, input_grid, output_grid, wavelength)\n",
    "    \n",
    "    @property\n",
    "    def apodization(self):\n",
    "        return self._apodization\n",
    "\n",
    "    @apodization.setter\n",
    "    def apodization(self, apodization):\n",
    "        self._apodization = apodization\n",
    "\n",
    "        self.clear_cache()\n",
    "\n",
    "    def get_input_grid(self, output_grid, wavelength):\n",
    "        return output_grid\n",
    "\n",
    "    def get_output_grid(self, input_grid, wavelength):\n",
    "        return input_grid\n",
    "\n",
    "    @make_agnostic_forward\n",
    "    def forward(self, instance_data, wavefront):\n",
    "        wf = wavefront.copy()\n",
    "        wf.electric_field *= instance_data.apodization\n",
    "\n",
    "        return wf\n",
    "\n",
    "    @make_agnostic_backward\n",
    "    def backward(self, instance_data, wavefront):\n",
    "        wf = wavefront.copy()\n",
    "        wf.electric_field *= np.conj(instance_data.apodization)\n",
    "\n",
    "        return wf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Apodizer` class is derived from `AgnosticOpticalElement` rather than `OpticalElement`. The `AgnosticOpticalElement` class provides the ability of an optical element to more easily work correctly for a wide range of inputs. It allows the optical element to be written mostly expecting a single input `Grid` and wavelength. The data required for each set of grid and wavelength is cached, so expensive calculations are not redone each time a `Wavefront` is propagated through. \n",
    "\n",
    "Setup code for an `AgnosticOpticalElement` is done in two different functions. The first is the initializer `__init__()`. Anything done in this function should not depend on wavelength or input grid. The above implmenetation just stores the value of `apodization` whatever that may be, and calls the `__init__()` of the parent class. This initializer has a few parameters, indicating if the `AgnosticOpticalELlement` class should create new instances when a `Wavefront` differs in `Grid` or in wavelength. For example, if our optical element is achromatic, but can still depend on the input grid, we would set `wavelength_dependent` to `False`. The last argument is the maximum number of instances to keep in the internal cache. If the class stores a lot of data per instance, you might run out of memory if too many instances are kept in memory. This value can be reduced if necessary (or increased if the set up time per instance is high, but memory usage is small).\n",
    "\n",
    "The second initializer is `make_instance()`. This function should calculate any data that depends on input grid, output grid and/or wavelength. Any calculated data should be stored in `instance_data` which is the object that is stored in the cache. In the implmementation above, we evalute the `apodization` at the supplied input grid, output grid and wavelength. This is done with the `evaluate_parameter()` function, which does all the heavy lifting with figuring out what type of parameter it actually is (ie. is it dependent on grid, or wavelength, or both, or none), and evaluates with the appropriate signature.\n",
    "\n",
    "The `apodization` is implemented in the class using a `property`. The reason for this is that we want to clear the cache when we change the `apodization`. If we wouldn't do this, we would still use the old apodization during propagation on a grid and wavelength that we have seen before, and were in the cache already.\n",
    "\n",
    "The next two functions let the `AgnosticOpticaElement` know what the corresponding input grid is, given an output grid and wavelength, and vice versa. This allows backwards propagation through an optical element, without having done a forward propagation first. (The exact reason for this is a bit harder to explain. When doing a backward propagation, it needs to look up the instance data in the cache. At that point, only the output grid is available, while the cache would have `instance_data`s listed by input grid. In this case it is necessary to know the input grid corresponding to previously calculated instances.)\n",
    "\n",
    "Finally, we have the `forward()` and `backward()` functions. These are implemened with the use of the `make_agnostic_forward()` and `make_agnostic_backward()` decorators, which look up the correct `instance_data` corresponding to the supplied `Wavefront` and call the `forward()` or `backward()` functions with these as well. Then, inside the propagation functions themselves, you have access to the evaluated parameters via `instance_data`. In the case of the `Apodizer`, we have access to the `instance_data.apodization`, and we propagate through the apodizer in the usual way."
   ]
  },
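  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the caching idea concrete, the cell below sketches a bounded cache keyed by grid and wavelength in plain Python. This is not HCIPy's actual implementation: the class `InstanceCacheSketch`, the function `expensive_setup` and the least-recently-used eviction rule are all assumptions made for this illustration, and a tuple stands in for a `Grid`. It only shows why repeated propagations at a known grid and wavelength avoid redoing the setup, and why `max_in_cache` trades memory for setup time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "class InstanceCacheSketch:\n",
    "    '''Hypothetical bounded cache of per-(grid, wavelength) data.'''\n",
    "    def __init__(self, make_instance, max_in_cache=2):\n",
    "        self.make_instance = make_instance\n",
    "        self.max_in_cache = max_in_cache\n",
    "        self._cache = OrderedDict()\n",
    "\n",
    "    def get(self, grid, wavelength):\n",
    "        key = (grid, wavelength)\n",
    "        if key in self._cache:\n",
    "            # Seen before: mark as most recently used and reuse the data.\n",
    "            self._cache.move_to_end(key)\n",
    "            return self._cache[key]\n",
    "\n",
    "        # Not seen before: run the expensive setup once and store it.\n",
    "        data = self.make_instance(grid, wavelength)\n",
    "        self._cache[key] = data\n",
    "\n",
    "        if len(self._cache) > self.max_in_cache:\n",
    "            # Assumed policy: evict the least recently used entry.\n",
    "            self._cache.popitem(last=False)\n",
    "        return data\n",
    "\n",
    "    def clear(self):\n",
    "        # Needed whenever a parameter changes, as in the property setter above.\n",
    "        self._cache.clear()\n",
    "\n",
    "setup_calls = []\n",
    "\n",
    "def expensive_setup(grid, wavelength):\n",
    "    # Record each call so we can count how often the setup really runs.\n",
    "    setup_calls.append((grid, wavelength))\n",
    "    return {'grid': grid, 'wavelength': wavelength}\n",
    "\n",
    "cache = InstanceCacheSketch(expensive_setup, max_in_cache=2)\n",
    "grid = ('pupil', 128)  # hashable stand-in for a Grid\n",
    "\n",
    "cache.get(grid, 500e-9)  # new: setup runs\n",
    "cache.get(grid, 500e-9)  # cached: no setup\n",
    "cache.get(grid, 600e-9)  # new: setup runs\n",
    "cache.get(grid, 700e-9)  # new: setup runs, oldest entry is evicted\n",
    "cache.get(grid, 500e-9)  # evicted earlier: setup runs again\n",
    "\n",
    "print('number of setup calls:', len(setup_calls))"
   ]
  },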
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adding properties based on other properties\n",
    "\n",
    "Suppose now that we would want to know the phase of the apodization in the apodizer. If we would have the apodization as a `Field` or a scalar, it would be easy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "apodization = Field(np.random.randn(pupil_grid.size) + 1j * np.random.randn(pupil_grid.size), pupil_grid)\n",
    "phase = np.angle(apodization)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But in case the apodization is a function, things are not that trivial. The `AgnosticOpticalElement` also provides a method for simplifying this operation. We will inherit from `Apodizer` and create a `SpecialApodizer` with this new method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SpecialApodizer(Apodizer):\n",
    "    def __init__(self, apodization):\n",
    "        Apodizer.__init__(self, apodization)\n",
    "    \n",
    "    @property\n",
    "    def phase(self):\n",
    "        return self.construct_function(np.angle, self.apodization)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `construct_function()` method of `AgnosticOpticalElement` returns a function with the encompassing signature. This means that if one of the arguments of the function depends on input grid, then the returned function will also be a function of input grid. If at least one of the arguments of the function is a function of wavelength, then the returned function will also be a function of wavelength. The non of the arguments were functions, then the supplied function is evaluated with the supplied arguments. \n",
    "\n",
    "The net effect is that, for this function, the `phase` will behave in the same way as `apodization` did."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "Implementing your own optical elements can be as easy as deriving from the `OpticalElement` class and overloading (= rewriting) the `forward()` and `backward()` functions. \n",
    "\n",
    "When arguments to your optical element depend on the grid or the wavelength of the propagated wavefront, it sometimes is easier to derive from `AgnosticOpticalElement` instead. This class provides several methods and a caching strategy to simplify otherwise complicated optical elements. The best way to really get an insight in the internals of HCIPy is to look in the source code. A good place to start is are the apodizers, `Apodizer`, `PhaseApodizer` and `SurfaceApodizer`. These classes provide an easy and uncomplicated starting point. If the user is familiar with polarization and Jones matrices, the `JonesMatrixOpticalElement` and `PhaseRetarder` classes are also good examples of relatively simple agnostic optical elements."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "level": "advanced",
  "thumbnail_figure_index": -1
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
